from .reward_function_base import BaseRewardFunction

class EventDrivenReward(BaseRewardFunction):
    """
    - 如果代理被导弹击落: -200
    - 如果代理意外坠毁: -200
    - 如果代理发射的导弹成功击落其他飞机: +200
    """
    def __init__(self, config):
        super().__init__(config)
        
    def get_reward(self, task, env, agent_id):
        """
        奖励是所有事件奖励的和
        """
        reward = 0
        if env.agents[agent_id].is_shotdown:
            reward -= 200
        elif env.agents[agent_id].is_crash:
            reward -= 200
        for missile in env.agents[agent_id].launch_missiles:
            if missile.is_success:
                reward += 200
        return self._process(reward, agent_id)